{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MF7BncmmLBeO"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn import datasets\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as tt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**DISCLAIMER**\n",
    "\n",
    "The presented code is not optimized, it serves an educational purpose. It is written for CPU, it uses only fully-connected networks and an extremely simplistic dataset. However, it contains all components that can help to understand how flow matching works, and it should be rather easy to extend it to more sophisticated models. This code could be run almost on any laptop/PC, and it takes a couple of minutes top to get the result."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RKsmjLumL5A2"
   },
   "source": [
    "## Dataset: Digits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we go wild and use a dataset that is simpler than MNIST! We use a scipy dataset called Digits. It consists of ~1500 images of size 8x8, and each pixel can take values in $\\{0, 1, \\ldots, 16\\}$.\n",
    "\n",
    "The goal of using this dataset is that everyone can run it on a laptop, without any gpu etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hSWUnXAYLLif"
   },
   "outputs": [],
   "source": [
    "class Digits(Dataset):\n",
    "    \"\"\"Scikit-Learn Digits dataset.\"\"\"\n",
    "\n",
    "    def __init__(self, mode='train', transforms=None):\n",
    "        digits = load_digits()\n",
    "        if mode == 'train':\n",
    "            self.data = digits.data[:1000].astype(np.float32)\n",
    "        elif mode == 'val':\n",
    "            self.data = digits.data[1000:1350].astype(np.float32)\n",
    "        else:\n",
    "            self.data = digits.data[1350:].astype(np.float32)\n",
    "\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sample = self.data[idx]\n",
    "        if self.transforms:\n",
    "            sample = self.transforms(sample)\n",
    "        return sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qSP2qiMqMICK"
   },
   "source": [
    "## Mixture of Gaussians"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MoG(nn.Module):\n",
    "    def __init__(self, D, K, uniform=False):\n",
    "        super(MoG, self).__init__()\n",
    "\n",
    "        print('MoG by JT.')\n",
    "        \n",
    "        # hyperparams\n",
    "        self.uniform = uniform\n",
    "        self.D = D  # the dimensionality of the input\n",
    "        self.K = K  # the number of components\n",
    "        \n",
    "        # params\n",
    "        self.mu = nn.Parameter(torch.randn(1, self.K, self.D) * 0.25 + 0.5)\n",
    "        self.log_var = nn.Parameter(-3. * torch.ones(1, self.K, self.D))\n",
    "        \n",
    "        if self.uniform:\n",
    "            self.w = torch.zeros(1, self.K)\n",
    "            self.w.requires_grad = False\n",
    "        else:\n",
    "            self.w = nn.Parameter(torch.zeros(1, self.K))\n",
    "\n",
    "        # other\n",
    "        self.PI = torch.from_numpy(np.asarray(np.pi))\n",
    "    \n",
    "    def log_diag_normal(self, x, mu, log_var, reduction='sum', dim=1):\n",
    "        log_p = -0.5 * torch.log(2. * self.PI) - 0.5 * log_var - 0.5 * torch.exp(-log_var) * (x.unsqueeze(1) - mu)**2.\n",
    "        return log_p\n",
    "    \n",
    "    def forward(self, x, reduction='mean'):\n",
    "        # calculate components\n",
    "        log_pi = torch.log(F.softmax(self.w, 1))  # B x K, softmax is used for R^K -> [0,1]^K s.t. sum(pi) = 1\n",
    "        log_N = torch.sum(self.log_diag_normal(x, self.mu, self.log_var), 2)  # B x K, log-diag-Normal for K components\n",
    "        \n",
    "        # =====LOSS: Negative Log-Likelihood\n",
    "        NLL_loss = -torch.logsumexp(log_pi + log_N,  1)  # B\n",
    "        \n",
    "        # Final LOSS\n",
    "        if reduction == 'sum':\n",
    "            return NLL_loss.sum()\n",
    "        elif reduction == 'mean':\n",
    "            return NLL_loss.mean()\n",
    "        else:\n",
    "            raise ValueError('Either `sum` or `mean`.')\n",
    "\n",
    "    def sample(self, batch_size=64):\n",
    "        # init an empty tensor\n",
    "        x_sample = torch.empty(batch_size, self.D)\n",
    "        \n",
    "        # sample components\n",
    "        pi = F.softmax(self.w, 1)  # B x K, softmax is used for R^K -> [0,1]^K s.t. sum(pi) = 1\n",
    "                             \n",
    "        indices = torch.multinomial(pi, batch_size, replacement=True).squeeze()\n",
    "        \n",
    "        for n in range(batch_size):\n",
    "            indx = indices[n]  # pick the n-th component\n",
    "            x_sample[n] = self.mu[0,indx] + torch.exp(0.5*self.log_var[0,indx]) * torch.randn(self.D)\n",
    "        \n",
    "        return x_sample\n",
    "    \n",
    "    def log_prob(self, x, reduction='mean'):\n",
    "        with torch.no_grad():\n",
    "            # calculate components\n",
    "            log_pi = torch.log(F.softmax(self.w, 1))  # B x K, softmax is used for R^K -> [0,1]^K s.t. sum(pi) = 1\n",
    "            log_N = torch.sum(self.log_diag_normal(x, self.mu, self.log_var), 2)  # B x K, log-diag-Normal for K components\n",
    "        \n",
    "            # log_prob\n",
    "            log_prob = torch.logsumexp(log_pi + log_N,  1)  # B\n",
    "            \n",
    "            if reduction == 'sum':\n",
    "                return log_prob.sum()\n",
    "            elif reduction == 'mean':\n",
    "                return log_prob.mean()\n",
    "            else:\n",
    "                raise ValueError('Either `sum` or `mean`.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vUoPkTmrMVnx"
   },
   "source": [
    "## Evaluation and Training functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JvwmRoi7MVto"
   },
   "source": [
    "**Evaluation step, sampling and curve plotting**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JHx4RIqDLZe9"
   },
   "outputs": [],
   "source": [
    "def evaluation(test_loader, name=None, model_best=None, epoch=None):\n",
    "    # EVALUATION\n",
    "    if model_best is None:\n",
    "        # load best performing model\n",
    "        model_best = torch.load(name + '.model')\n",
    "\n",
    "    model_best.eval()\n",
    "    loss = 0.\n",
    "    N = 0.\n",
    "    for indx_batch, test_batch in enumerate(test_loader):\n",
    "        loss_t = -model_best.log_prob(test_batch, reduction='sum')\n",
    "        loss = loss + loss_t.item()\n",
    "        N = N + test_batch.shape[0]\n",
    "    loss = loss / N\n",
    "\n",
    "    if epoch is None:\n",
    "        print(f'FINAL LOSS: nll={loss}')\n",
    "    else:\n",
    "        print(f'Epoch: {epoch}, val nll={loss}')\n",
    "\n",
    "    return loss\n",
    "\n",
    "\n",
    "def samples_real(name, test_loader):\n",
    "    # REAL-------\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    x = next(iter(test_loader)).detach().numpy()\n",
    "\n",
    "    fig, ax = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(ax.flatten()):\n",
    "        plottable_image = np.reshape(x[i], (8, 8))\n",
    "        ax.imshow(plottable_image, cmap='gray')\n",
    "        ax.axis('off')\n",
    "\n",
    "    plt.savefig(name+'_real_images.pdf', bbox_inches='tight')\n",
    "    plt.close()\n",
    "\n",
    "\n",
    "def samples_generated(name, data_loader, extra_name=''):\n",
    "    with torch.no_grad():\n",
    "        # GENERATIONS-------\n",
    "        model_best = torch.load(name + '.model')\n",
    "\n",
    "        num_x = 4\n",
    "        num_y = 4\n",
    "        x = model_best.sample(batch_size=num_x * num_y)\n",
    "        x = x.detach().numpy()\n",
    "\n",
    "        fig, ax = plt.subplots(num_x, num_y)\n",
    "        for i, ax in enumerate(ax.flatten()):\n",
    "            plottable_image = np.reshape(x[i], (8, 8))\n",
    "            ax.imshow(plottable_image, cmap='gray')\n",
    "            ax.axis('off')\n",
    "\n",
    "        plt.savefig(name + '_generated_images' + extra_name + '.pdf', bbox_inches='tight')\n",
    "        plt.close()\n",
    "\n",
    "def plot_curve(name, nll_val):\n",
    "    plt.plot(np.arange(len(nll_val)), nll_val, linewidth='3')\n",
    "    plt.xlabel('epochs')\n",
    "    plt.ylabel('nll')\n",
    "    plt.savefig(name + '_nll_val_curve.pdf', bbox_inches='tight')\n",
    "    plt.close()\n",
    "\n",
    "def means_save(name, extra_name='', num_x = 4, num_y = 4):\n",
    "    with torch.no_grad():\n",
    "        # GENERATIONS-------\n",
    "        model_best = torch.load(name + '.model')\n",
    "\n",
    "        pi = F.softmax(model_best.w, 1).squeeze()\n",
    "\n",
    "        x = model_best.mu[:, 0:num_x * num_y]\n",
    "        N = x.shape[1]\n",
    "        x = x.squeeze(0).detach().numpy()\n",
    "\n",
    "        fig, ax = plt.subplots(int(np.sqrt(N)), int(np.sqrt(N)))\n",
    "        for i, ax in enumerate(ax.flatten()):\n",
    "            plottable_image = np.reshape(x[i], (8, 8))\n",
    "            ax.imshow(plottable_image, cmap='gray')\n",
    "            ax.set_title(f'$\\pi$ = {pi[i].item():.5f}')\n",
    "            ax.axis('off')\n",
    "        fig.tight_layout()\n",
    "        plt.savefig(name + '_means_images' + extra_name + '.pdf', bbox_inches='tight')\n",
    "        plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "umU3VYKzMbDt"
   },
   "source": [
    "**Training step**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NxkUZ1xVLbm_"
   },
   "outputs": [],
   "source": [
    "def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader):\n",
    "    nll_val = []\n",
    "    best_nll = 1000.\n",
    "    patience = 0\n",
    "\n",
    "    # Main loop\n",
    "    for e in range(num_epochs):\n",
    "        # TRAINING\n",
    "        model.train()\n",
    "        for indx_batch, batch in enumerate(training_loader):\n",
    "            loss = model.forward(batch)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward(retain_graph=True)\n",
    "            optimizer.step()\n",
    "\n",
    "        # Validation\n",
    "        loss_val = evaluation(val_loader, model_best=model, epoch=e)\n",
    "        nll_val.append(loss_val)  # save for plotting\n",
    "\n",
    "        if e == 0:\n",
    "            print('saved!')\n",
    "            torch.save(model, name + '.model')\n",
    "            best_nll = loss_val\n",
    "        else:\n",
    "            if loss_val < best_nll:\n",
    "                print('saved!')\n",
    "                torch.save(model, name + '.model')\n",
    "                best_nll = loss_val\n",
    "                patience = 0\n",
    "            else:\n",
    "                patience = patience + 1\n",
    "        \n",
    "        samples_generated(name, val_loader, extra_name=\"_epoch_\" + str(e))\n",
    "        \n",
    "        if patience > max_patience:\n",
    "            break\n",
    "\n",
    "    nll_val = np.asarray(nll_val)\n",
    "\n",
    "    return nll_val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0BXJ9dN0MinB"
   },
   "source": [
    "## Experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KsF7f-Q-MkWu"
   },
   "source": [
    "**Initialize datasets**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transforms = tt.Lambda(lambda x: (x/17.) + (np.random.randn(*x.shape)/136.))  # changing to [-1, 1] and adding small Gaussian noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fqZKMNM0LdQ1"
   },
   "outputs": [],
   "source": [
    "train_data = Digits(mode='train', transforms=transforms)\n",
    "val_data = Digits(mode='val', transforms=transforms)\n",
    "test_data = Digits(mode='test', transforms=transforms)\n",
    "\n",
    "training_loader = DataLoader(train_data, batch_size=32, shuffle=True)\n",
    "val_loader = DataLoader(val_data, batch_size=32, shuffle=False)\n",
    "test_loader = DataLoader(test_data, batch_size=32, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6lEKUznpMns7"
   },
   "source": [
    "**Hyperparameters**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ANQo7LrGLjIN"
   },
   "outputs": [],
   "source": [
    "D = 64   # input dimension\n",
    "\n",
    "K = 25  # the number of neurons in scale (s) and translation (t) nets\n",
    "\n",
    "lr = 1e-3 # learning rate\n",
    "num_epochs = 1000 # max. number of epochs\n",
    "max_patience = 20 # an early stopping is used, if training doesn't improve for longer than 20 epochs, it is stopped"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-7APXeunMrDh"
   },
   "source": [
    "**Creating a folder for results**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bjSUn1eWLkWm"
   },
   "outputs": [],
   "source": [
    "name = 'mog' + '_' + str(K)\n",
    "if not (os.path.exists('results/')):\n",
    "    os.mkdir(result_dir)\n",
    "result_dir = 'results/' + name + '/'\n",
    "if not (os.path.exists(result_dir)):\n",
    "    os.mkdir(result_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Hpwm6LWUMulQ"
   },
   "source": [
    "**Initializing the model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "FrnNsCqQLmK3",
    "outputId": "5f0cf2b1-0a96-4f5c-da9e-f78f909a5259"
   },
   "outputs": [],
   "source": [
    "# Eventually, we initialize the full model\n",
    "model = MoG(D=D, K=K, uniform=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3SzTemY3NSxO"
   },
   "source": [
    "**Optimizer - here we use Adamax**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "R9TZtLVtLoWc"
   },
   "outputs": [],
   "source": [
    "# OPTIMIZER\n",
    "# optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad == True], lr=lr, momentum=0.1, weight_decay=1.e-4)\n",
    "optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad == True], lr=lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dNf__W_ONVHA"
   },
   "source": [
    "**Training loop**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KhqHgluGLqIC",
    "outputId": "c52fa1e4-3376-4bff-9f87-6f03613c4e42",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Training procedure\n",
    "nll_val = training(name=result_dir + name, max_patience=max_patience, num_epochs=num_epochs, model=model, optimizer=optimizer,\n",
    "                       training_loader=training_loader, val_loader=val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-3XTxgEcNXfp"
   },
   "source": [
    "**The final evaluation**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "okK1mV_-LrRU",
    "outputId": "4664693f-742d-4453-94cf-d051d2efa9be"
   },
   "outputs": [],
   "source": [
    "test_loss = evaluation(name=result_dir + name, test_loader=test_loader)\n",
    "f = open(result_dir + name + '_test_loss.txt', \"w\")\n",
    "f.write(str(test_loss))\n",
    "f.close()\n",
    "\n",
    "samples_real(result_dir + name, test_loader)\n",
    "samples_generated(result_dir + name, test_loader, extra_name='FINAL')\n",
    "\n",
    "means_save(result_dir + name, extra_name='_'+str(K), num_x=5, num_y=5)\n",
    "\n",
    "plot_curve(result_dir + name, nll_val)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "vae_priors.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
